Implement cutlass_fused_moe mxfp8 #2581
Conversation
Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a new mixed-precision quantization scheme, MXFP8xMXFP8, for the

Highlights
Activity
Code Review
This pull request introduces support for mxfp8 x mxfp8 quantization in the cutlass_fused_moe kernel. The changes are extensive, touching kernel implementations, bindings, and adding new tests to validate the functionality. The implementation correctly adds new code paths for mxfp8 handling, including new quantization modes and scaling types.
My review focuses on improving code maintainability by addressing several instances of code duplication. I've identified repeated logic for determining scaling types and for dispatching kernels, and I've suggested refactoring these into helper functions or using other C++ patterns to reduce redundancy. These changes should make the code cleaner and easier to maintain in the future.
```cpp
auto fpX_scaling_type = getScalingType();
if constexpr (use_fp8) {
  if (use_mxfp8_fp8_block_scaling) {
    fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
  }
}
```
This logic to determine `fpX_scaling_type` is duplicated in `configureWsPtrs` (lines 2755-2760) and `setupTmaWarpSpecializedInputs` (lines 3933-3938). To improve maintainability and reduce code duplication, consider extracting this logic into a private helper function within the `CutlassMoeFCRunner` class.
For example:

```cpp
__host__ __device__ inline TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType
getFpXScalingTypeHelper(bool use_mxfp8_fp8_block_scaling) const {
  auto fpX_scaling_type = getScalingType();
  if constexpr (use_fp8) {
    if (use_mxfp8_fp8_block_scaling) {
      fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
    }
  }
  return fpX_scaling_type;
}
```

Then, you can call this helper function in all three places.
/bot run

@flashinfer-bot run

[FAILED] Pipeline #44336503: 7/20 passed
📝 Walkthrough

Adds MXFPX/MXFP8 activation-scaling and block-scaling support across the Cutlass MoE backend: new template parameter `IsMXFPX`.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Binding as Host Binding
    participant Runner as CutlassMoeFCRunner<IsMXFPX>
    participant MoeGemm as MoeGemmRunner<IsMXFPX>
    participant Device as GPU (CUTLASS / TMA)
    Binding->>Runner: prepare inputs, quant params, use_mxfp8_act_scaling
    Runner->>MoeGemm: get workspace sizes, configureWsPtrs, select kernel
    MoeGemm->>Device: dispatch GEMM/TMA (MXFPX or non-MXFP path)
    Device-->>MoeGemm: kernel results
    MoeGemm-->>Runner: post-process (accumulate/gate/dequant)
    Runner-->>Binding: final outputs
```
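For readers unfamiliar with the MX formats this walkthrough refers to: MXFP8 stores values as FP8 (E4M3) elements that share one power-of-two (E8M0) scale per 32-element block. The round trip can be sketched in plain Python — a simplified emulation only; the real kernels operate on packed tensors and their exact rounding and subnormal handling differ:

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3
BLOCK = 32        # MX block size: one shared scale per 32 elements (not enforced here)

def quantize_mxfp8_block(xs):
    """Quantize one block to (shared_exponent, e4m3-like values) -- a sketch."""
    amax = max(abs(x) for x in xs) or 1.0
    # E8M0 shared scale: smallest power of two keeping amax within E4M3 range.
    exp = math.ceil(math.log2(amax / E4M3_MAX))
    scale = 2.0 ** exp

    def to_e4m3(v):
        # Crude e4m3 emulation: clamp to range, keep 1+3 mantissa bits.
        v = max(-E4M3_MAX, min(E4M3_MAX, v))
        if v == 0.0:
            return 0.0
        m, e = math.frexp(v)  # v = m * 2**e with 0.5 <= |m| < 1
        return (round(m * 16) / 16) * 2.0 ** e

    return exp, [to_e4m3(x / scale) for x in xs]

def dequantize_mxfp8_block(exp, qs):
    scale = 2.0 ** exp
    return [q * scale for q in qs]
```

For example, quantizing `[1.0, -2.5, 0.25, 100.0]` reproduces the first three values exactly and recovers 100.0 to within e4m3 relative precision (about 6%).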
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

✅ Passed checks (2 passed)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)
1389-1392: GPU architecture skip should use `flashinfer.utils` functions per coding guidelines.

The `@pytest.mark.skipif` uses `torch.cuda.get_device_capability()` directly. The same pattern is used elsewhere in this file (e.g., lines 481-484, 1256-1258), but the coding guidelines require using `flashinfer.utils` functions like `get_compute_capability` or `is_sm90a_supported` for architecture-gated skips. Consider updating all new tests to follow this guideline.

As per coding guidelines:

> tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.

```shell
#!/bin/bash
# Check what flashinfer.utils functions are available for capability checks
rg -n "def get_compute_capability|def is_sm" --type=py -g '*/utils*'
# Check if any test files in the repo already follow the guideline
rg -n "from flashinfer.utils import|flashinfer.utils.get_compute" --type=py -g 'tests/**'
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_cutlass_fused_moe.py` around lines 1389 - 1392: The skip decorator in the test uses torch.cuda.get_device_capability() directly; replace it with the utility helpers from flashinfer.utils (e.g., import and call get_compute_capability() or the appropriate is_smXX_supported helper) so GPU-architecture gating follows project guidelines; update the `@pytest.mark.skipif` on the test to call flashinfer.utils.get_compute_capability() or is_sm90a_supported/is_sm100_supported as appropriate and mirror the pattern used in other tests in this file.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (1)
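The suggested decorator swap can be sketched as follows. The helpers are stubbed here — in the real library `get_compute_capability` lives in `flashinfer.utils` and queries the active CUDA device, so treat the stub bodies and the SM100 threshold as assumptions:

```python
import pytest

def get_compute_capability():
    # Stub standing in for flashinfer.utils.get_compute_capability, which
    # queries the active GPU; hard-coded to Hopper (SM90) for illustration.
    return (9, 0)

def is_sm100_supported():
    # Hypothetical helper: the MXFP8 x MXFP8 path targets SM100+ here.
    return get_compute_capability() >= (10, 0)

requires_sm100 = pytest.mark.skipif(
    not is_sm100_supported(),
    reason="MXFP8 x MXFP8 fused MoE requires SM100 or newer",
)

@requires_sm100
def test_moe_mxfp8_mxfp8():
    # Placeholder body; the real test builds MXFP8 weights and runs the MoE.
    assert True
```

With the stub returning `(9, 0)`, the marker skips the test, which is exactly the architecture gating the review asks for.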
889-900: fc1_weight_block validation always assumes gated (×2) activation — correct for current usage but potentially fragile.

The check at line 894 unconditionally multiplies `inter_size` by 2 for the fc1 N-dimension. This is correct for SwiGLU/SwigluBias (gated activations), which is the only activation type currently used with MXFP8. However, the NVFP4 path (further below) conditionally handles both gated and non-gated cases via `isGatedActivation(base_activation_type)`.

If non-gated activations are supported for MXFP8 in the future, this check will need to be updated.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu` around lines 889 - 900: The fc1_weight_block size check unconditionally assumes a gated (×2) N-dimension; update the validation in the block that references fc1_weight_block, TmaWarpSpecializedGroupedGemmInput::alignToSfDim, and inter_size to only multiply inter_size by 2 when the activation is gated (use isGatedActivation(base_activation_type) or equivalent), otherwise use inter_size as-is, and adjust the error message to reflect both gated and non-gated expected shapes so the check matches the NVFP4 branch behavior.
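The gated/non-gated shape expectation in this thread can be restated as a small helper — a Python sketch only; the default alignment value stands in for `MinNDimAlignmentMXFPX` and is an assumption:

```python
def align_to_sf_dim(dim, alignment):
    """Round dim up to a multiple of alignment (mirrors alignToSfDim)."""
    return ((dim + alignment - 1) // alignment) * alignment

def expected_fc1_block_n(inter_size, is_gated, min_n_align=128):
    """Expected N-dimension of the fc1 weight block-scale tensor.

    Gated activations (SwiGLU and friends) fuse the gate and up projections,
    so fc1's N dimension is inter_size * 2; non-gated activations use
    inter_size directly -- the conditional the review asks for.
    """
    return align_to_sf_dim(inter_size, min_n_align) * (2 if is_gated else 1)
```

With this in place the ×2 assumption becomes explicit: `expected_fc1_block_n(3072, True)` yields 6144, while the non-gated case yields 3072.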
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh`:
- Around line 3578-3586: The current gate use_mxfp8_weight_block_scales only
checks fc1.weight_block_scale and can still dereference fc2.global_scale; change
the logic to enable MXFP8 block-scale mode only when fp8_scales_required
indicates MXFP8 (same flag used elsewhere) AND both
quant_params.mxfp8_mxfp4.fc1.weight_block_scale and
quant_params.mxfp8_mxfp4.fc2.weight_block_scale are set and both corresponding
global_scale pointers are non-null; then use that boolean to choose
fc1_fp8_dequant and fc2_fp8_dequant (otherwise fall back to quant_params.fp8.*
dequant pointers). Ensure the symbol names mentioned
(use_mxfp8_weight_block_scales, quant_params.mxfp8_mxfp4.fc1.weight_block_scale,
quant_params.mxfp8_mxfp4.fc2.weight_block_scale,
quant_params.mxfp8_mxfp4.fc1.global_scale,
quant_params.mxfp8_mxfp4.fc2.global_scale, fp8_scales_required, fc1_fp8_dequant,
fc2_fp8_dequant) are used to locate and update the condition and selections.
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1448-1462: Add a GPU-architecture skip to the test function
test_moe_mxfp8_mxfp8 by decorating it with pytest.mark.skipif using the
flashinfer.utils helpers: call get_compute_capability() and pass it to
is_sm90a_supported (or the appropriate helper for MXFP8 support) and skip when
those indicate unsupported hardware; ensure you import pytest and the helpers
(get_compute_capability, is_sm90a_supported) and place the skip decorator
immediately above the test_moe_mxfp8_mxfp8 definition so the test is skipped on
unsupported GPUs.
---
Duplicate comments:
In `@csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh`:
- Around line 4108-4117: This is a duplicate of the MXFP8 validation/gating code
— remove the redundant block and consolidate the FP8 gating logic so only the
earlier validated ternary selection for dequant scales is used; specifically,
keep a single place that computes the fc1/fc2 dequant scale (the ternaries using
std::is_same_v<WeightType, __nv_fp8_e4m3> and
quant_params.mxfp8_mxfp4.fc?.global_scale vs quant_params.fp8.dequant_fc?) and
pass those same values along with fc1_expert_weights, fc2_expert_weights,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases,
fc2_bias to the kernel — delete or merge the duplicated lines to avoid repeated
gating/validation.
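The `use_mxfp8_weight_block_scales` gating described in these prompts — enable MXFP8 block-scale mode only when both experts' block scales and global scales are actually present — can be sketched in Python (field names taken from the comments above; the flat dataclass is a simplification of the real quant-params structs):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ExpertScales:
    weight_block_scale: Optional[object] = None
    global_scale: Optional[object] = None

def use_mxfp8_weight_block_scales(fc1: ExpertScales, fc2: ExpertScales,
                                  mxfp8_mode_requested: bool) -> bool:
    """True only when MXFP8 mode is requested AND both fc1 and fc2 have
    block scales and global scales set.

    Checking fc1.weight_block_scale alone (the original condition) could
    still dereference a null fc2 global_scale pointer, which is the bug
    the review comment describes.
    """
    return (mxfp8_mode_requested
            and fc1.weight_block_scale is not None
            and fc2.weight_block_scale is not None
            and fc1.global_scale is not None
            and fc2.global_scale is not None)
```

The caller would then pick the MXFP8 dequant pointers when this returns True and fall back to the plain FP8 dequant pointers otherwise.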
Hopper still failing. Can confirm

Can confirm H100 also passes:
```cpp
fc1_weight_block.size(0) == num_experts_on_rank &&
    fc1_weight_block.size(1) ==
        TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
            inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
```
Is this only supposed to work with gated activations? The bf16 variant of this kernel supports both gated and non-gated activations.
It works with gating; see the unit test here: https://github.com/zianglih/flashinfer/blob/aba577ad95f7998b46616dd5c0fa7f8b1818f717/tests/moe/test_trtllm_cutlass_fused_moe.py#L1417-L1418
Also I have tried this kernel in SGLang sgl-project/sglang#18945 and can run Qwen3-30B-A3B without problems.
My question was about non-gated activations like squared ReLU. Does it work with them? I tested yesterday and it did not.
Hi, I locally reproduced the H100 errors.

This was the error I was seeing:

Hi @aleozlx, this is exactly the CI error. Have you tried clearing any cache? Thanks!

This is a fresh container; I don't think there is anything to clear?

I switched to the main branch, and the tests passed.

Now trying to run the tests after

I'll also retry on Hopper.

Hi @aleozlx, rerunning both Blackwell and Hopper now; Blackwell also shows the failure now. Let me debug.
Head branch was pushed to by a user without write access
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
406-516: ⚠️ Potential issue | 🔴 Critical

Verify that explicit instantiation TUs have matching IsMXFPX=false variants for all SM versions and type combinations now being dispatched.

The lambda `dispatch_by_mxfpx` is instantiated twice at runtime (lines 513, 515) with both `IsMXFPX=true` and `IsMXFPX=false`, causing the nested SHAPE_CASE macros to generate references to `tma_warp_specialized_generic_moe_gemm_kernelLauncher<>` with both boolean variants across all supported SM versions (SM90, SM100, SM120). This widens the linker surface: legacy FP8 paths now bake `IsMXFPX=false` into symbols that may not have explicit instantiations in the compilation units.

If `moe_gemm_kernels_fp8_fp8.cu` and related type-specific TUs do not emit matching `IsMXFPX=false` variants for all the tile/cluster shapes and epilogue configurations now being dispatched, this change is the root cause of the undefined symbol errors in `fused_moe_90.so`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h` around lines 406 - 516: The change instantiates dispatch_by_mxfpx with both IsMXFPX=true and false, which causes references to tma_warp_specialized_generic_moe_gemm_kernelLauncher<> (via dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized and the SHAPE_CASE macro) for SM90/100/120 even when no explicit TU instantiations exist; fix by adding matching explicit instantiations for the IsMXFPX=false variants (all relevant template parameter combinations: T, WeightType, EpilogueTag, FUSION and all tile shapes used in SHAPE_CASE for SM90/100/120) into the corresponding type-specific TUs (e.g., the FP8/FP4 kernel TUs), or alternately narrow the runtime dispatch so dispatch_by_mxfpx is only instantiated for the boolean value that actually has corresponding explicit instantiations, to avoid generating undefined symbols.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`:
- Around line 491-504: The SM120 branch can silently fall through when
kernels::cutlass_kernels::isValidSM120MOESpecialisation<T,WeightType,EpilogueTag,FUSION>()
is false causing count==0 and a no-op; update the else path for the if constexpr
in the gemm_config.sm_version == 120/121 block to fail loudly (throw an
exception or call a fatal logger) with a clear message referencing SM120 and the
tile_config_sm120 value (use pretty_function or include
gemm_config.tile_config_sm120 in the message) so invalid SM120 specialisations
cannot proceed; ensure the thrown exception type matches surrounding error
handling conventions.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0ec1b1ce-bd64-4a88-a347-ddfc149c8f83
📒 Files selected for processing (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
```cpp
} else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
  char const* const pretty_function = __PRETTY_FUNCTION__;
  TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
                 (int)(gemm_config.tile_config_sm120));
  if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
                    T, WeightType, EpilogueTag, FUSION>()) {
    switch (gemm_config.tile_config_sm120) {
      SHAPE_CASE(120, 128, 128, 64)
      SHAPE_CASE(120, 128, 128, 128)
      SHAPE_CASE(120, 128, 256, 64)
      SHAPE_CASE(120, 256, 128, 64)
      DEFAULT_CASE(120)
    }
  }
}
```
Throw on invalid SM120 specialisations instead of silently falling through.
If `isValidSM120MOESpecialisation<...>()` is false, this branch currently exits without a throw. In the workspace path that can leave `count == 0`, and in the execution path it becomes a no-op instead of a hard failure.
💡 Proposed fix

```diff
 } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
   char const* const pretty_function = __PRETTY_FUNCTION__;
   TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
                  (int)(gemm_config.tile_config_sm120));
   if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
                     T, WeightType, EpilogueTag, FUSION>()) {
     switch (gemm_config.tile_config_sm120) {
       SHAPE_CASE(120, 128, 128, 64)
       SHAPE_CASE(120, 128, 128, 128)
       SHAPE_CASE(120, 128, 256, 64)
       SHAPE_CASE(120, 256, 128, 64)
       DEFAULT_CASE(120)
     }
+  } else {
+    TLLM_THROW("Unsupported SM120 configuration requested");
   }
 }
```
The previous Blackwell failure I saw was due to a stale ninja file, irrelevant to this PR. The Hopper failure is a real regression, now fixed by 6ab67b4.
/bot run

@flashinfer-bot run
## 📌 Description

@HumansAnd

flashinfer-ai#2505 implements mxfp8 for the trtllm backend. However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses the SGLang topk implementation and does not work with expert routing replay in MoE RL. We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe`, which works with MoE RL training.

This PR mainly reuses the existing code path for `WMxfp4AMxfp8Quant`:
https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191

## 🔍 Related Issues

- miles MXFP8/NVFP4 RL roadmap: radixark/miles#615
- SGLang FlashInfer MXFP8 integration: sgl-project/sglang#18945

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

- **New Features**
  - Toggleable MXFPX/MXFP8 activation-scaling across MOE inference, updating workspace sizing, kernel selection, block-scaling and dispatch to enable MXFP8-aware execution and validation.
  - Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling flag.
- **Tests**
  - Added unit tests and helpers for MXFP8 quantization, packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference validation.
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Fixes #2731.

## What's broken?

When using the CUTLASS fused MoE backend with **non-gated activations** (e.g., Relu2, Gelu, Silu) and MXFP8 quantization, the fc1 weight shape validation unconditionally rejects the input — even when the shape is correct.

## Who is affected?

Anyone using the **CUTLASS fused MoE** path with:

- **Quantization**: `WMxfp8AMxfp8`, `WMxfp4AFp8`, or `WMxfp4AMxfp8`
- **Activation**: any non-gated type (Relu2, Gelu, Silu, etc.)

Not affected: gated activations (Swiglu, Geglu, SwigluBias), or other quant modes (NVFP4 already handles this correctly).

## Where is the bug?

`csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`, inside `getQuantParams()` — the fc1 weight block N-dimension check hardcodes `* 2` at three MXFP8 branches (~L898, ~L1004, ~L1063).

## Why does it happen?

PR #2581 introduced MXFP8 support when only gated activations (Swiglu) existed, so `inter_size * 2` was correct. Later, non-gated activation support was added to the trtllm-gen backend (PR #2707), but the CUTLASS backend's validation was never updated. The NVFP4 path in the same file (line ~1131) already handles this correctly with an `if (isGatedActivation(...))` guard.

## How did we fix it?

For each of the 3 MXFP8 quant branches:

1. Extract `int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;`
2. Replace the hardcoded `* 2` with `* fc1_n_mult`
3. Update error messages: gated shows `"inter_size * 2"`, non-gated shows `"inter_size"`

**Before:**

```cpp
fc1_weight_block.size(1) == alignToSfDim(inter_size, ...) * 2
```

**After:**

```cpp
int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;
fc1_weight_block.size(1) == alignToSfDim(inter_size, ...) * fc1_n_mult
```

## How do we know it works?

- `pre-commit run` passes (clang-format, lint, etc.)
- Gated activations (default Swiglu): `fc1_n_mult = 2` — identical to old behavior, no regression
- Non-gated activations: `fc1_n_mult = 1` — shape check now accepts the correct `inter_size` dimension
- Full GPU test suite requires CI (`@flashinfer-bot run`)

## Related

- Builds on the approach identified in #2753 (stale ~27 days, CI unresolved).
- Addresses the Gemini review feedback from #2753 by extracting the multiplier to a local variable before the validation checks.

cc @aleozlx @nv-yunzheq

## Summary by CodeRabbit

- **Bug Fixes**
  - Fixed weight block size validation for Mixture of Experts (MOE) to correctly handle both gated and non-gated activation types, ensuring proper support across different activation configurations.

Signed-off-by: Yiyang Liu <37043548+ianliuy@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
